import java.util.ArrayList;
import java.util.List;

/**
 * Counts the nodes of a binary tree that achieve the highest "score".
 *
 * <p>The tree is given as a parent array where {@code parents[0] == -1} marks node 0 as the root
 * and {@code parents[i]} is the parent of node {@code i} for {@code i >= 1}. Removing a node
 * splits the tree into up to three components: its left subtree, its right subtree, and the rest
 * of the tree. The node's score is the product of the sizes of the non-empty components.
 */
public class Solution2049 {

    /**
     * Builds child adjacency lists from a parent array.
     *
     * @param parents parent array; index 0 is treated as the root and is skipped
     * @return {@code childrenInfo[p]} lists the children of node {@code p} in increasing index order
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // generic array creation is unavoidable here
    private List<Integer>[] getChildrenInfo(int[] parents) {
        int nodeCount = parents.length;
        List<Integer>[] childrenInfo = new List[nodeCount];
        for (int i = 0; i < nodeCount; i++) {
            childrenInfo[i] = new ArrayList<>(2); // binary tree: at most two children
        }
        for (int i = 1; i < nodeCount; i++) {
            childrenInfo[parents[i]].add(i);
        }
        return childrenInfo;
    }

    /**
     * Fills, for every node reachable from {@code root}, the subtree size of each of its children.
     *
     * <p>Iterative on purpose: a recursive walk would use one stack frame per tree level and can
     * overflow the stack on a degenerate (chain-shaped) tree with many nodes.
     *
     * @param root node whose subtree is measured
     * @param childrenInfo child lists as produced by {@link #getChildrenInfo(int[])}
     * @param childrenCountInfo output; {@code childrenCountInfo[v][i]} receives the size of the
     *     subtree rooted at the {@code i}-th child of {@code v}. Slots for missing children are
     *     left untouched. Each row must have room for every child of its node (two for a binary tree).
     * @return number of nodes in the subtree rooted at {@code root}
     */
    private int fillCountInfo(int root, List<Integer>[] childrenInfo, int[][] childrenCountInfo) {
        int nodeCount = childrenInfo.length;
        // BFS order: every parent appears before all of its descendants.
        int[] order = new int[nodeCount];
        int head = 0;
        int tail = 0;
        order[tail++] = root;
        while (head < tail) {
            int node = order[head++];
            for (int child : childrenInfo[node]) {
                order[tail++] = child;
            }
        }
        // Walk the BFS order backwards so every child's size is known before its parent's.
        int[] subtreeSize = new int[nodeCount];
        for (int k = tail - 1; k >= 0; k--) {
            int node = order[k];
            List<Integer> children = childrenInfo[node];
            int size = 1;
            for (int i = 0; i < children.size(); i++) {
                int childSize = subtreeSize[children.get(i)];
                childrenCountInfo[node][i] = childSize;
                size += childSize;
            }
            subtreeSize[node] = size;
        }
        return subtreeSize[root];
    }

    /**
     * Returns how many nodes share the highest score.
     *
     * @param parents parent array with {@code parents[0] == -1}
     * @return number of nodes whose score equals the maximum score; 0 for an empty array
     */
    public int countHighestScoreNodes(int[] parents) {
        int nodeCount = parents.length;
        if (nodeCount == 0) {
            return 0;
        }
        List<Integer>[] childrenInfo = getChildrenInfo(parents);
        int[][] childrenCountInfo = new int[nodeCount][2];
        fillCountInfo(0, childrenInfo, childrenCountInfo);
        long highestScore = 0;
        int highestCount = 0;
        for (int[] nodeCountInfo : childrenCountInfo) {
            int childCount1 = nodeCountInfo[0];
            int childCount2 = nodeCountInfo[1];
            int restCount = nodeCount - childCount1 - childCount2 - 1;
            // Widen to long BEFORE multiplying: the product can reach ~(n/2)^2 or ~(n/3)^3,
            // far beyond Integer.MAX_VALUE. Empty components contribute a factor of 1.
            long score = (long) Math.max(childCount1, 1)
                    * Math.max(childCount2, 1)
                    * Math.max(restCount, 1);
            if (score > highestScore) {
                highestScore = score;
                highestCount = 1;
            } else if (score == highestScore) {
                highestCount++;
            }
        }
        return highestCount;
    }

    /** Small demo: prints the answer for a sample tree (expected output: 3). */
    public static void main(String[] args) {
        Solution2049 s = new Solution2049();
        System.out.println(s.countHighestScoreNodes(new int[]{-1, 2, 0, 2, 0}));
    }
}
